classdef SkillAndParticipation < handle
    % SkillAndParticipation  Combines the skill distribution and the
    % participation-cost distribution.
    %
    % For each sector X in {M, R, C} the (unnormalized) density of
    % participants at wage w is
    %
    %     h_X(w) = G_X(phi_tilde(w)) .* f_X(w),
    %
    % where f_X is provided by dens_skill, G_X by dens_part, and
    % phi_tilde(w) is the utility at wage w (with HSV taxes) minus the
    % utility of not working and receiving par.transfer.
    %
    % Integrals over w use nodes/weights pre-computed on [-1, 1] by
    % tools.Integration.lgwt and passed to tools.Integration.integral_legendre
    % (for whole wage ranges) or tools.Integration.mass_bins (for wage bins).

    properties
        dens_skill  % skill distribution object; must provide f_M, f_R, f_C
        dens_part   % participation-cost distribution; must provide G_M, G_R, G_C
        util        % utility object; must provide epsilon and utility(c, l, xi)

        scale_inc   % income in the model scaled down by this factor
                    % (compared to income in the data)
    end

    properties (SetAccess = protected)
        num_nodes_f     % number of nodes for integration over :math:`f_i(w)`
        num_nodes_g     % forwarded to dens_skill.set_num_nodes (not used directly here)
        num_nodes_bins  % number of nodes for integration within each wage bin

        nodes_f         % pre-computed nodes for integration over :math:`f_i(w)`
        weights_f       % pre-computed weights for integration over :math:`f_i(w)`
        nodes_bins      % pre-computed nodes for integration within bins of :math:`f_i(w)`
        weights_bins    % pre-computed weights for integration within bins of :math:`f_i(w)`
    end

    methods
        function self = SkillAndParticipation(varargin)
            % Construct from name/value pairs:
            %   'dens_skill', 'dens_part', 'util'   (default [])
            %   'num_nodes_f'    positive integer scalar (default 10)
            %   'num_nodes_g'    positive integer scalar (default 15)
            %   'num_nodes_bins' positive integer scalar (default 10)
            %   'scale_inc'      positive numeric scalar (default 1)
            % Unrecognized name/value pairs are ignored (KeepUnmatched).
            p = inputParser;
            p.KeepUnmatched = true;

            valid_num_nodes = @(x) isnumeric(x) && isscalar(x) && ...
                x == floor(x) && x > 0;

            p.addParameter('dens_skill', [])
            p.addParameter('dens_part', [])
            p.addParameter('util', [])

            p.addParameter('num_nodes_f', 10, valid_num_nodes);
            p.addParameter('num_nodes_g', 15, valid_num_nodes);
            p.addParameter('num_nodes_bins', 10, valid_num_nodes);
            p.addParameter('scale_inc', 1, ...
                @(x) isnumeric(x) && isscalar(x) && x > 0);

            p.parse(varargin{:});

            self.dens_skill = p.Results.dens_skill;
            self.dens_part = p.Results.dens_part;
            self.util = p.Results.util;
            self.num_nodes_f = p.Results.num_nodes_f;
            self.num_nodes_g = p.Results.num_nodes_g;
            self.num_nodes_bins = p.Results.num_nodes_bins;
            self.scale_inc = p.Results.scale_inc;

            self.compute_quadrature();
        end

        function set_num_nodes(self, num_nodes_f, num_nodes_g, num_nodes_bins, ~)
            % Update the node counts and recompute nodes and weights.
            % Then forward the counts to dens_skill, if it is set and has a
            % set_num_nodes method. An optional fourth argument is
            % accepted and ignored.
            valid_num_nodes = @(x) isnumeric(x) && isscalar(x) && ...
                x == floor(x) && x > 0;

            p = inputParser;
            p.addRequired('num_nodes_f', valid_num_nodes);
            p.addRequired('num_nodes_g', valid_num_nodes);
            p.addRequired('num_nodes_bins', valid_num_nodes);

            p.parse(num_nodes_f, num_nodes_g, num_nodes_bins);

            self.num_nodes_f = p.Results.num_nodes_f;
            self.num_nodes_g = p.Results.num_nodes_g;
            self.num_nodes_bins = p.Results.num_nodes_bins;

            % re-compute nodes and weights
            self.compute_quadrature();

            % Set nodes and weights for the dens_skill object. Skip this
            % when it is absent or lacks the method, instead of silently
            % swallowing every error with a bare try/end.
            if ~isempty(self.dens_skill) && ...
                    ismethod(self.dens_skill, 'set_num_nodes')
                self.dens_skill.set_num_nodes(num_nodes_f, num_nodes_g, ...
                    num_nodes_bins);
            end
        end

        function out = average_wage_M(self, Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par)
            % Mean of w under h_M on [w_lb, w_ub]:
            %   int w*h_M(w) dw / int h_M(w) dw
            % (computed as integral / (emp_M * mass_participants)).
            integrand = @(w) w .* self.h_M(w, Y_M, Y_R, Y_C, par);
            share = self.emp_M(Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par);
            out = 1/share * 1/mass_participants * ...
                self.integrate_f(integrand, w_lb, w_ub);
        end

        function out = average_wage_R(self, Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par)
            % Mean of w under h_R on [w_lb, w_ub] (see average_wage_M).
            integrand = @(w) w .* self.h_R(w, Y_M, Y_R, Y_C, par);
            share = self.emp_R(Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par);
            out = 1/share * 1/mass_participants * ...
                self.integrate_f(integrand, w_lb, w_ub);
        end

        function out = average_wage_C(self, Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par)
            % Mean of w under h_C on [w_lb, w_ub] (see average_wage_M).
            integrand = @(w) w .* self.h_C(w, Y_M, Y_R, Y_C, par);
            share = self.emp_C(Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par);
            out = 1/share * 1/mass_participants * ...
                self.integrate_f(integrand, w_lb, w_ub);
        end

        function out = average_wage(self, Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par)
            % (1/mass_participants) * int w*h(w) dw over [w_lb, w_ub],
            % where h = h_M + h_R + h_C.
            integrand = @(w) w .* self.h(w, Y_M, Y_R, Y_C, par);
            out = 1/mass_participants * ...
                self.integrate_f(integrand, w_lb, w_ub);
        end

        function out = emp_M(self, Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par)
            % (1/mass_participants) * int h_M(w) dw over [w_lb, w_ub].
            fct_handle = @(w) self.h_M(w, Y_M, Y_R, Y_C, par);
            out = 1/mass_participants * ...
                self.integrate_f(fct_handle, w_lb, w_ub);
        end

        function out = emp_R(self, Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par)
            % (1/mass_participants) * int h_R(w) dw over [w_lb, w_ub].
            fct_handle = @(w) self.h_R(w, Y_M, Y_R, Y_C, par);
            out = 1/mass_participants * ...
                self.integrate_f(fct_handle, w_lb, w_ub);
        end

        function out = emp_C(self, Y_M, Y_R, Y_C, w_lb, w_ub, mass_participants, par)
            % (1/mass_participants) * int h_C(w) dw over [w_lb, w_ub].
            fct_handle = @(w) self.h_C(w, Y_M, Y_R, Y_C, par);
            out = 1/mass_participants * ...
                self.integrate_f(fct_handle, w_lb, w_ub);
        end

        function mass = total_mass_wagebins(self, Y_M, Y_R, Y_C, wagebins_M, wagebins_R, wagebins_C, mass_participants, par)
            % Sum over all bins of the normalized bin masses of h_M, h_R
            % and h_C, each evaluated on its own vector of bin edges.
            mass_bins_h_M = self.h_M_mass_wagebins(Y_M, Y_R, Y_C, wagebins_M, ...
                mass_participants, par);
            mass_bins_h_R = self.h_R_mass_wagebins(Y_M, Y_R, Y_C, wagebins_R, ...
                mass_participants, par);
            mass_bins_h_C = self.h_C_mass_wagebins(Y_M, Y_R, Y_C, wagebins_C, ...
                mass_participants, par);
            mass = sum(mass_bins_h_M) + sum(mass_bins_h_R) + sum(mass_bins_h_C);
        end

        function mass = h_M_mass_wagebins(self, Y_M, Y_R, Y_C, wagebins, mass_participants, par)
            % Mass of h_M within each bin defined by the edges in
            % wagebins (via tools.Integration.mass_bins), divided by
            % mass_participants.
            fct_handle = @(w) self.h_M(w, Y_M, Y_R, Y_C, par);
            mass = 1/mass_participants * self.integrate_bins(fct_handle, wagebins);
        end

        function mass = h_R_mass_wagebins(self, Y_M, Y_R, Y_C, wagebins, mass_participants, par)
            % Per-bin mass of h_R divided by mass_participants.
            fct_handle = @(w) self.h_R(w, Y_M, Y_R, Y_C, par);
            mass = 1/mass_participants * self.integrate_bins(fct_handle, wagebins);
        end

        function mass = h_C_mass_wagebins(self, Y_M, Y_R, Y_C, wagebins, mass_participants, par)
            % Per-bin mass of h_C divided by mass_participants.
            fct_handle = @(w) self.h_C(w, Y_M, Y_R, Y_C, par);
            mass = 1/mass_participants * self.integrate_bins(fct_handle, wagebins);
        end

        function mass = h_mass_wagebins(self, Y_M, Y_R, Y_C, wagebins, mass_participants, par)
            % Per-bin mass of h = h_M + h_R + h_C divided by
            % mass_participants.
            fct_handle = @(w) self.h(w, Y_M, Y_R, Y_C, par);
            mass = 1/mass_participants * self.integrate_bins(fct_handle, wagebins);
        end

        function mass = mass_participants(self, Y_M, Y_R, Y_C, w_lb, w_ub, par)
            % int h(w) dw over [w_lb, w_ub] (not normalized). Issues a
            % warning if the result is exactly zero, since the other
            % methods divide by this quantity.
            fct_handle = @(w) self.h(w, Y_M, Y_R, Y_C, par);
            mass = self.integrate_f(fct_handle, w_lb, w_ub);
            if mass == 0
                warning('participation rate is zero')
            end
        end

        function out = h(self, w, Y_M, Y_R, Y_C, par)
            % Sum of the sector densities: h_M + h_R + h_C at w.
            out = self.h_M(w, Y_M, Y_R, Y_C, par) + ...
                self.h_R(w, Y_M, Y_R, Y_C, par) + ...
                self.h_C(w, Y_M, Y_R, Y_C, par);
        end

        function out = h_M(self, w, Y_M, Y_R, Y_C, par)
            % participation_M(w) .* dens_skill.f_M(w, ...), elementwise in w.
            out = self.participation_M(w, par) ...
                .* self.dens_skill.f_M(w, Y_M, Y_R, Y_C, par);
        end

        function out = h_R(self, w, Y_M, Y_R, Y_C, par)
            % participation_R(w) .* dens_skill.f_R(w, ...), elementwise in w.
            out = self.participation_R(w, par) ...
                .* self.dens_skill.f_R(w, Y_M, Y_R, Y_C, par);
        end

        function out = h_C(self, w, Y_M, Y_R, Y_C, par)
            % participation_C(w) .* dens_skill.f_C(w, ...), elementwise in w.
            out = self.participation_C(w, par) ...
                .* self.dens_skill.f_C(w, Y_M, Y_R, Y_C, par);
        end

        % NOTE: the participation_* methods differ only in which G_X of
        % dens_part they call. They could be merged into one participation()
        % if G_M, G_R and G_C coincide (at w).
        function out = participation_M(self, w, par)
            % dens_part.G_M evaluated at phi_tilde(w).
            out = self.dens_part.G_M(self.phi_tilde(w, par), par);
        end

        function out = participation_R(self, w, par)
            % dens_part.G_R evaluated at phi_tilde(w).
            out = self.dens_part.G_R(self.phi_tilde(w, par), par);
        end

        function out = participation_C(self, w, par)
            % dens_part.G_C evaluated at phi_tilde(w).
            out = self.dens_part.G_C(self.phi_tilde(w, par), par);
        end

        function out = total_labor_income(self, Y_M, Y_R, Y_C, w_lb, w_ub, par)
            % int y_hsv(w) * h(w) dw over [w_lb, w_ub] (scaled income,
            % not divided by mass_participants).
            epsilon = self.util.epsilon;

            integrand = @(w) self.y_hsv(w, epsilon, par.lambda_hsv, par.tau_hsv, par.xi) .* ...
                self.h(w, Y_M, Y_R, Y_C, par);

            out = self.integrate_f(integrand, w_lb, w_ub);
        end


        % For derivations of the tax-dependent expressions, see
        % HeathcoteTaxes2.nb.

        function out = tax_hsv_y(self, y, lambda, tau)
            % HSV tax on unscaled income y: T(y) = y - lambda * y^(1 - tau).
            out = y - lambda .* y.^(1 - tau);
        end

        function out = tax_hsv_unscaled(self, w, epsilon, lambda, tau, xi)
            % HSV tax (unscaled) on income l_hsv(w)*w*scale_inc.
            y_unscaled = self.l_hsv(w, epsilon, lambda, tau, xi) .* w .* ...
                self.scale_inc;
            out = self.tax_hsv_y(y_unscaled, lambda, tau);
        end

        function out = tax_hsv(self, w, epsilon, lambda, tau, xi)
            % HSV tax in scaled units (USD/scale_inc).
            out = 1./self.scale_inc .* self.tax_hsv_unscaled(w, epsilon, lambda, tau, xi);
        end

        function out = mtax_hsv_y(self, y, lambda, tau)
            % HSV marginal tax rate at unscaled income y:
            % T'(y) = 1 - (1 - tau) * lambda * y^(-tau).
            out = 1 - (1 - tau) .* lambda .* y.^(-tau);
        end

        function out = mtax_hsv(self, w, epsilon, lambda, tau, xi)
            % HSV marginal tax rate at unscaled income l_hsv(w)*w*scale_inc.
            y_unscaled = self.l_hsv(w, epsilon, lambda, tau, xi) .* w .* ...
                self.scale_inc;
            out = self.mtax_hsv_y(y_unscaled, lambda, tau);
        end

        function out = l_hsv(self, w, epsilon, lambda, tau, xi)
            % Closed-form labor supply at wage w under HSV taxes:
            %   l = (lambda * scale_inc^(-tau) * (1-tau) * xi * w^(1-tau))
            %       ^ (epsilon / (1 + epsilon*tau)).
            % The scale_inc factor accounts for scaling down income by
            % scale_inc, so that the model is in terms of USD/scale_inc.
            % The arithmetic is kept exactly as derived, to preserve
            % floating-point results.
            out = (lambda .* self.scale_inc.^((-1).*tau) .* (1+(-1).*tau) .* xi ...
                .* w.^(1+(-1).*tau)) .^ (epsilon .* (1+epsilon.*tau).^(-1));
        end

        function out = y_hsv(self, w, epsilon, lambda, tau, xi)
            % Labor income l_hsv(w) .* w in scaled units (USD/scale_inc).
            out = self.l_hsv(w, epsilon, lambda, tau, xi) .* w;
        end

        function out = phi_tilde(self, w, par)
            % Utility from working at wage w minus utility from not working:
            %   util.utility(y - tax, l, xi) - util.utility(par.transfer, 0, xi),
            % with l = l_hsv(w), y = l*w and tax = tax_hsv(w), all in
            % scaled units.
            % Uncomment the line below to print a warning whenever
            % phi_tilde is computed for a given tax system (debugging only).
            % warning('computing phi_tilde for given tax system')
            transfer = par.transfer;
            xi = par.xi;

            tau = par.tau_hsv;
            lambda = par.lambda_hsv;
            epsilon = self.util.epsilon;

            labor = self.l_hsv(w, epsilon, lambda, tau, xi);
            y = labor .* w; % in USD/scale_inc
            tax = self.tax_hsv(w, epsilon, lambda, tau, xi);
            consumption = y - tax;

            V_w = self.util.utility(consumption, labor, xi);
            V_0 = self.util.utility(transfer, 0, xi);
            out = V_w - V_0;
        end

    end

    methods (Access = protected)
        function compute_quadrature(self)
            % (Re)compute the nodes and weights on [-1, 1] for the current
            % num_nodes_f and num_nodes_bins.
            [self.nodes_f, self.weights_f] = ...
                tools.Integration.lgwt(self.num_nodes_f, -1, 1);
            [self.nodes_bins, self.weights_bins] = ...
                tools.Integration.lgwt(self.num_nodes_bins, -1, 1);
        end

        function out = integrate_f(self, fct_handle, w_lb, w_ub)
            % Integrate fct_handle over [w_lb, w_ub] with the pre-computed
            % f-nodes and weights.
            out = tools.Integration.integral_legendre(fct_handle, w_lb, w_ub, ...
                self.num_nodes_f, self.nodes_f, self.weights_f);
        end

        function out = integrate_bins(self, fct_handle, wagebins)
            % Per-bin integrals of fct_handle over the bins defined by
            % the edges in wagebins, using the pre-computed bin nodes
            % and weights.
            out = tools.Integration.mass_bins(fct_handle, wagebins, ...
                self.num_nodes_bins, self.nodes_bins, self.weights_bins);
        end
    end

end